# Train a KNN classifier on labelled character crops and use it to read arithmetic captchas
import warnings
from urllib import request
import cv2
import numpy as np
import sys
import os
import http.cookiejar
from utils import logging_method as log
import time
from entity import crawler_config
from utils import config_util

# Modes for handling a character rectangle: draw its outline, or cut it out
draw = 0
cut = 1
# Directory where the labelled character crops are stored
img_lib_path = config_util.root_path + "/img_lib"
# File the downloaded captcha image is written to and read back from
captcha_path = config_util.root_path + "/img/cztl-web-captcha.jpeg"
# Module-level counters: number of tests run and number of correct recognitions
test_times = 0
right_times = 0


def test_bound_result(num=10):
    """
    Preview the detected character boundaries on freshly downloaded captchas.

    :param num: how many captchas to download and preview
    :return: None
    """
    for _ in range(num):
        # draw mode only shows the outlined image; nothing is counted
        count_correct_times(cut_type=draw)


def test_accuracy(num=10):
    """
    Measure the recognition accuracy over several captchas.

    For every character shown the operator types the correct value, which is
    compared with the predicted label. The summary is computed from the
    module-level counters ``test_times`` and ``right_times``.

    :param num: number of captchas to test
    :return: None
    """
    for _ in range(num):
        count_correct_times()
    if test_times == 0:
        # No character was answered (num <= 0, or every captcha was skipped):
        # there is no ratio to report and dividing would raise ZeroDivisionError.
        print("共测试0次，无法计算正确率")
        return
    print("共测试%d次，正确率百分之%.2f" % (test_times, right_times / test_times * 100))


def cut_and_mark():
    """
    Download a captcha, cut it into character crops and have each crop
    labelled by hand and saved.

    :return: None
    """
    load_captcha()
    binary_img, char_rects = deal_image(captcha_path)
    for piece in cut_rect_contours(binary_img, char_rects, cut):
        # A microsecond timestamp is used in the file name to avoid name clashes
        stamp = int(time.time() * 1e6)
        mark_img(piece, stamp)


def count_correct_times(cut_type=cut):
    """
    Download one captcha and either preview its character boxes or score the
    recogniser against characters typed by the operator.

    With ``cut_type == draw`` the image with the boxes drawn on it is shown and
    nothing is counted. Otherwise every character crop is shown in turn and the
    operator presses the key of the character it really is: Esc terminates the
    program, Enter abandons the rest of this captcha. Each answered character
    adds one to the module-level ``test_times``; a match with the predicted
    label adds one to ``right_times``. A typed ``*`` matches the label ``++``.

    :param cut_type: ``draw`` to preview the boxes, ``cut`` (default) to score
    :return: None
    """
    global test_times
    global right_times
    load_captcha()
    img_dst, rects = deal_image(captcha_path)
    # List of crops (cut mode) or of outlined images (draw mode), one per rectangle
    images = cut_rect_contours(img_dst, rects, cut_type)
    if cut_type == draw:
        # With no rectangle the list is empty; show the processed image itself
        # instead of failing on images[0].
        show_img("bound", images[0] if images else img_dst)
        return
    if not images:
        # Nothing to classify, so there is no point in training the model
        return
    id_label_map, model = train_machine()
    for image in images:
        sample = image.reshape((1, 900)).astype(np.float32)
        _, results, _, _ = model.findNearest(sample, k=3)
        label = id_label_map[int(results[0, 0])]
        cv2.imshow("image", image)
        key = cv2.waitKey(0)
        if key == 27:
            sys.exit()
        # A negative code is not a key press and would make chr() raise ValueError.
        # NOTE(review): presumably returned by some HighGUI backends when the
        # window is closed - handled like Enter here; confirm the desired behaviour.
        if key == 13 or key < 0:
            return
        correct_char = chr(key)
        print("您输入的key是：%s,机器识别的key是：%s" % (correct_char, label))
        test_times += 1
        if label == correct_char or (correct_char == "*" and label == "++"):
            right_times += 1


def get_captcha_val():
    """
    Recognise the captcha image stored at ``captcha_path`` and evaluate it.

    :return: whatever ``cut_and_calc`` returns for the processed image
    """
    binary_img, char_rects = deal_image(captcha_path)
    # Cut out the characters, recognise them and compute the arithmetic result
    answer = cut_and_calc(binary_img, char_rects)
    return answer


def deal_image(image_path):
    """
    Pre-process a captcha image: take one colour channel, binarise it and
    apply the perspective stretch, then locate the character rectangles.

    :param image_path: path of the image file to read
    :return: (processed binary image, list of character rectangles)
    """
    source = cv2.imread(image_path)
    # The middle channel of the split is used as the grey image
    # (a cv2.cvtColor grey-scale conversion would be the alternative)
    _, channel, _ = cv2.split(source)
    # Inverse binary threshold at 80
    binary_inv = cv2.threshold(channel, 80, 255, cv2.THRESH_BINARY_INV)[1]
    # Perspective stretch
    stretched = img_perspective(binary_inv)
    # Bounding rectangles of the characters
    char_rects = get_rect_contours(stretched)
    return stretched, char_rects


def show_img(title, source):
    """
    Show an image in a window and block until a key is pressed.

    When the key is Esc the window that was opened here is destroyed.

    :param title: window title
    :param source: image to display
    :return: None
    """
    cv2.imshow(title, source)
    if cv2.waitKey(0) == 27:
        # Close the window created above; the previous hard-coded name "test1"
        # never matched the window opened by this function.
        cv2.destroyWindow(title)


def load_captcha():
    """
    Download the captcha image and store it at ``captcha_path``.

    Emits a ``DeprecationWarning`` on every call. Installs a cookie-aware
    opener as the global ``urllib.request`` opener, so cookies received here
    are kept in the jar and sent with later requests made through it.

    :return: None
    """
    warnings.warn("过时的方法", DeprecationWarning)
    # Cookie processor: stores cookies sent by the server and attaches them to requests
    cj = http.cookiejar.CookieJar()
    cookie_support = request.HTTPCookieProcessor(cj)
    opener = request.build_opener(cookie_support, request.HTTPHandler)
    request.install_opener(opener)
    req = request.Request(crawler_config.get_conf("captcha_url"))

    # Close the response deterministically instead of leaving it to the GC
    with request.urlopen(req) as response:
        res = response.read()
    with open(captcha_path, "wb") as image:
        image.write(res)


def img_perspective(img):
    """
    Apply a fixed perspective transform to the image.

    :param img: source image
    :return: transformed image of size 160x60
    """
    out_size = (160, 60)
    # Four reference points in the source and where they should land
    src_quad = np.array([[0, 0], [135, 0], [30, 60], [160, 60]], dtype=np.float32)
    dst_quad = np.array([[25, 0], [160, 0], [30, 60], [160, 60]], dtype=np.float32)
    matrix = cv2.getPerspectiveTransform(src_quad, dst_quad)
    return cv2.warpPerspective(img, matrix, out_size)


def draw_rect_bound(img_dst, x, y, w, h):
    """
    Draw a one-pixel rectangle outline on the binarised image; mainly useful
    while tuning the character boundaries.

    :param img_dst: binarised image to draw on
    :param x: x coordinate of the top-left corner
    :param y: y coordinate of the top-left corner
    :param w: width
    :param h: height
    :return: the value returned by ``cv2.rectangle`` (the image with the outline)
    """
    top_left = (x, y)
    bottom_right = (x + w, y + h)
    white = (255, 255, 255)
    return cv2.rectangle(img_dst, top_left, bottom_right, white, thickness=1)


def cut_char_img(img_dst, x, y, w, h):
    """
    Cut one rectangular character out of the binarised image and scale it to
    30x30.

    The rectangle outline is drawn onto ``img_dst`` before the crop is taken,
    so the image passed in is modified.

    :param img_dst: binarised image
    :param x: x coordinate of the top-left corner
    :param y: y coordinate of the top-left corner
    :param w: width
    :param h: height
    :return: the 30x30 character crop
    """
    # np.int0 was removed in NumPy 2.0; build the integer corner array explicitly.
    box = np.array([[x, y], [x + w, y], [x + w, y + h], [x, y + h]], dtype=np.int32)
    # NOTE(review): on a single-channel image presumably only the first colour
    # component (0) is used, i.e. the outline is drawn in black - confirm.
    cv2.drawContours(img_dst, [box], 0, (0, 0, 255), 2)
    roi = img_dst[y:y + h, x:x + w]
    # Normalise every character crop to 30x30
    return cv2.resize(roi, (30, 30))


def cut_rect(img_dst, x, y, w, h, cut_type=draw):
    """
    Process one rectangle of the image, either outlining it or cutting it out
    depending on ``cut_type``.

    :param img_dst: source image
    :param x: x coordinate
    :param y: y coordinate
    :param w: width
    :param h: height
    :param cut_type: ``draw`` to outline the rectangle, ``cut`` to crop it
    :return: the processed image, or None for any other ``cut_type``
    """
    if cut_type == draw:
        handler = draw_rect_bound
    elif cut_type == cut:
        handler = cut_char_img
    else:
        return None
    return handler(img_dst, x, y, w, h)


def cut_and_analyse(img_dst, x, y, w, h):
    """
    Cut one character out of the image and return the label the model
    predicts for it.

    The model is rebuilt by ``train_machine`` on every call.

    :param img_dst: binarised image
    :param x: x coordinate of the character rectangle
    :param y: y coordinate of the character rectangle
    :param w: width of the rectangle
    :param h: height of the rectangle
    :return: predicted label
    """
    char_img = cut_char_img(img_dst, x, y, w, h)
    id_label_map, model = train_machine()
    # Flatten the 30x30 crop into a single float32 row of 900 values
    features = char_img.reshape((1, 900)).astype(np.float32)
    prediction = model.findNearest(features, k=3)
    results = prediction[1]
    return id_label_map[int(results[0, 0])]


def get_rect_contours(img_dst):
    """
    Find the character rectangles in the processed image, sorted from left to
    right by x coordinate.

    Rectangles taken to be the question mark, the equals sign or tiny
    fragments are dropped; rectangles wide enough to hold several touching
    characters are split into 20-pixel slices.

    :param img_dst: processed binary image
    :return: list of ``(x, y, w, h)`` tuples in ascending order
    """
    # OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
    # (contours, hierarchy); the contours are second from the end in both.
    contours = cv2.findContours(img_dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    rects = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Exclude the question mark
        if (w < 10 and x > 100) or (x > 110 and h > 10):
            continue
        # Exclude the equals sign
        if 2 < w / h < 6 and x > 90:
            continue
        # Two touching characters: usually the second digit with the operator,
        # the second digit with the equals sign, or the first digit with the operator.
        # The first 20 px slice is appended by the fall-through at the bottom;
        # the second slice is kept only when the rectangle starts left of x = 80.
        if 28 < w < 50:
            w = 20
            if x < 80:
                rects.append((x + 20, y, w, h))
        # Three touching characters: operator + second digit + equals sign,
        # or first digit + operator + second digit.
        # NOTE(review): w == 50 matches neither this nor the previous branch and
        # is kept as one rectangle - confirm whether that gap is intended.
        if w > 50:
            rects.append((x, y, 20, h))
            rects.append((x + 20, y, 20, h))
            if x < 50:
                rects.append((x + 40, y, w - 40, h))
            continue
        # Give up on fragments of a '*' that was split too finely
        if w < 10 and h < 10:
            continue
        rects.append((x, y, w, h))
    # Tuples compare element-wise, so this orders by x first
    rects.sort()
    return rects


def cut_rect_contours(img_dst, rects, cut_type=draw):
    """
    Process every rectangle of the image, either outlining or cutting it.

    :param img_dst: image to process
    :param rects: sequence of ``(x, y, w, h)`` rectangles
    :param cut_type: ``draw`` or ``cut``, passed on to ``cut_rect``
    :return: list with the result of ``cut_rect`` for each rectangle
    """
    return [
        cut_rect(img_dst, box[0], box[1], box[2], box[3], cut_type)
        for box in rects
    ]


def cut_and_calc(img_dst, rects):
    """
    Recognise the characters of the captcha and evaluate the expression
    ``<number> <operator> <number>``.

    Increments the module-level ``test_times`` once per evaluated captcha.
    The label ``++`` is treated as multiplication.

    :param img_dst: processed binary image
    :param rects: character rectangles, left to right
    :return: the computed value; -1 when the operator is not recognised; None
             when more than three characters were found or the characters could
             not be evaluated
    """
    global test_times
    chars = []
    for rect in rects:
        if len(chars) == 3:
            log.error("识别到的字符串多于3个")
            return None
        chars.append(cut_and_analyse(img_dst, rect[0], rect[1], rect[2], rect[3]))
    result = -1
    test_times = test_times + 1
    try:
        num1 = int(chars[0])
        operator = chars[1]
        num2 = int(chars[2])
        print("识别到的验证码为：%s%s%s" % (chars[0], chars[1], chars[2]), end=" ")
        if operator == "+":
            result = num1 + num2
        elif operator == "-":
            result = num1 - num2
        elif operator == "++":
            result = num1 * num2
        elif operator == "/":
            # NOTE(review): true division yields a float (e.g. 3.0) - confirm
            # that callers do not need an integer here.
            result = num1 / num2
        else:
            log.error("运算符识别错误")
        print("=", result)
        return result
    except (IndexError, ValueError, ZeroDivisionError):
        # Fewer than three characters, a non-numeric operand or a division by
        # zero. Only these are caught, so KeyboardInterrupt and SystemExit are
        # no longer swallowed as they were by the bare except.
        log.error("数字识别错误")
        return None


def mark_img(roi, timestamp):
    """
    Show one character crop, read the key the operator presses and save the
    crop to the labelled library as ``<timestamp>_<label>.jpg``.

    Esc terminates the program and Enter skips the crop. A typed ``*`` is
    stored under the label ``++``.

    :param roi: character crop to label
    :param timestamp: value used as the first part of the file name
    :return: None
    """
    print("PS:对每张切图输入对应的字符（用于标记切图），回车跳过当前切图，点击关闭退出人工标记切图")
    cv2.imshow("image", roi)
    key = cv2.waitKey(0)
    # A negative code is not a key press and would make chr() raise ValueError.
    # NOTE(review): presumably returned by some HighGUI backends when the window
    # is closed; the hint printed above says closing the window quits, so it is
    # handled like Esc - confirm.
    if key == 27 or key < 0:
        sys.exit()
    if key == 13:
        return
    char = chr(key)
    print("您输入的key是：", char)
    if char == "*":
        char = "++"
    filename = "%s/%s_%s.jpg" % (img_lib_path, timestamp, char)
    # imwrite reports failure through its return value (e.g. missing directory,
    # or a label such as "/" that is not valid inside a file name).
    if not cv2.imwrite(filename, roi):
        log.error("切图保存失败：%s" % filename)


def train_machine():
    """
    Train a k-nearest-neighbours model from the labelled crops in
    ``img_lib_path``.

    The label of a crop is the part of its file name after the last ``_``,
    without the extension. Files that cannot be read as images are skipped.

    :return: (id_label_map, model) - a dict mapping the numeric class id back
             to its label, and the trained ``cv2.ml.KNearest`` model
    """
    # TODO: consider caching and persisting the result
    rows = []
    labels = []
    # Sorted so the sample order does not depend on the directory listing order
    for filename in sorted(os.listdir(img_lib_path)):
        filepath = "%s/%s" % (img_lib_path, filename)
        im = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
        if im is None:
            # imread returns None for anything it cannot decode (sub-directory,
            # stray non-image file); resizing it would raise.
            continue
        # splitext keeps a "." label intact, unlike splitting on the first dot
        stem = os.path.splitext(filename)[0]
        labels.append(stem.split("_")[-1])
        rows.append(cv2.resize(im, (30, 30)).reshape(900))
    # Build the sample matrix once instead of growing it with np.append per file;
    # the reshape keeps the (0, 900) shape when no sample was found.
    samples = np.array(rows, dtype=np.float32).reshape((-1, 900))
    # Sorted so the label <-> id mapping is the same on every run
    unique_labels = sorted(set(labels))
    id_label_map = dict(enumerate(unique_labels))
    label_id_map = {label: label_id for label_id, label in id_label_map.items()}
    label_ids = np.array(
        [label_id_map[label] for label in labels], dtype=np.float32
    ).reshape((-1, 1))
    model = cv2.ml.KNearest_create()
    model.train(samples, cv2.ml.ROW_SAMPLE, label_ids)
    return id_label_map, model




